{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**This notebook generates videos of the uncertainties predicted by PixLoc on sequences of images.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn.functional import interpolate\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "from datetime import datetime\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from pixloc.utils.data import Paths\n",
    "from pixloc.localization.feature_extractor import FeatureExtractor\n",
    "from pixloc.pixlib.utils.experiments import load_experiment\n",
    "from pixloc.pixlib.datasets.view import read_image\n",
    "from pixloc.visualization.viz_2d import plot_images, add_text\n",
    "from pixloc.visualization.animation import VideoWriter, display_video\n",
    "from pixloc.settings import DATA_PATH, LOC_PATH"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Datasets\n",
    "Run one of the following cells to parse the list of images in a sequence."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1) CMU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequences: ['2010-09-01', '2010-09-15', '2010-10-01', '2010-10-19', '2010-10-26', '2010-11-03', '2010-11-12', '2010-11-22', '2010-12-21', '2011-03-04', '2011-07-28']\n"
     ]
    }
   ],
   "source": [
    "from datetime import timezone\n",
    "\n",
    "from pixloc.run_CMU import default_paths, experiment\n",
    "\n",
    "# The last '_'-separated field of a query file name is '<timestamp>us.jpg';\n",
    "# the timestamp is divided by 1e6 below to obtain seconds.\n",
    "ts_suffix = 'us.jpg'\n",
    "\n",
    "paths = default_paths.add_prefixes(DATA_PATH/'CMU', LOC_PATH/'CMU')\n",
    "\n",
    "# Group the query images of slices 2 to 25 by their UTC capture date.\n",
    "date2query = defaultdict(list)\n",
    "for slice_ in range(2, 26):\n",
    "    queries = sorted(paths.interpolate(slice=slice_).query_images.glob('*.jpg'))\n",
    "    for q in queries:\n",
    "        # Slice the suffix off explicitly: str.rstrip() removes a *set of\n",
    "        # characters*, not a suffix, and only worked here by coincidence.\n",
    "        stamp = q.name.split('_')[-1]\n",
    "        if not stamp.endswith(ts_suffix):\n",
    "            raise ValueError(f'Unexpected query image name: {q.name}')\n",
    "        ts = int(stamp[:-len(ts_suffix)])\n",
    "        # Timezone-aware equivalent of the deprecated datetime.utcfromtimestamp().\n",
    "        date = datetime.fromtimestamp(ts//1000000, tz=timezone.utc)\n",
    "        date2query[date.strftime('%Y-%m-%d')].append((ts, q))\n",
    "date2query = dict(date2query)\n",
    "\n",
    "print('Sequences:', sorted(date2query.keys()))\n",
    "\n",
    "# Select one capture date and one camera; `images` is the list consumed by\n",
    "# the inference cell.\n",
    "seq = '2010-12-21'\n",
    "cam = 'c1'\n",
    "images = [p for _, p in date2query[seq] if cam in p.name]\n",
    "# Optional sanity check: plot the timestamps of the selected frames.\n",
    "# plt.plot([ts for ts, p in date2query[seq] if cam in p.name]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2) Cambridge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequences: {'seq3': 87, 'seq2': 161, 'seq1': 301, 'seq9': 122, 'seq5': 68, 'seq6': 87, 'seq8': 126, 'seq4': 56, 'seq7': 76}\n"
     ]
    }
   ],
   "source": [
    "from pixloc.run_Cambridge import default_paths, experiment\n",
    "\n",
    "# Root the default paths at DATA_PATH/Cambridge and LOC_PATH/Cambridge,\n",
    "# then fill in the scene name.\n",
    "paths = (\n",
    "    default_paths\n",
    "    .add_prefixes(DATA_PATH/'Cambridge', LOC_PATH/'Cambridge')\n",
    "    .interpolate(scene='OldHospital')\n",
    ")\n",
    "\n",
    "# Map every 'seq*' folder of the query images to its sorted list of entries.\n",
    "all_sequences = {}\n",
    "for seq in paths.query_images.glob('seq*'):\n",
    "    all_sequences[seq.name] = sorted(seq.iterdir())\n",
    "\n",
    "sequence_lengths = {name: len(frames) for name, frames in all_sequences.items()}\n",
    "print('Sequences:', sequence_lengths)\n",
    "\n",
    "# `images` is the list consumed by the inference cell.\n",
    "cam = 'seq1'\n",
    "images = all_sequences[cam]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3) Aachen v1.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequences: {'gopro': 1046, 'nexus_1': 171, 'nexus_3': 17, 'nexus_7': 242, 'nexus_6': 84, 'nexus_2': 43, 'nexus_5': 203, 'nexus_8': 71, 'nexus_4': 494}\n"
     ]
    }
   ],
   "source": [
    "from pixloc.run_Aachen import default_paths, experiment\n",
    "\n",
    "paths = default_paths.add_prefixes(DATA_PATH/'Aachen_v1.1', LOC_PATH/'Aachen_v1.1')\n",
    "root = paths.reference_images / 'sequences'\n",
    "\n",
    "# The PNG files of gopro3_undistorted form one sequence.\n",
    "all_sequences = {'gopro': sorted((root/'gopro3_undistorted').glob('*.png'))}\n",
    "# Every entry of nexus4_sequences is one more sequence, keyed by the second\n",
    "# '_'-separated field of its name.\n",
    "for seq in (root/'nexus4_sequences').iterdir():\n",
    "    seq_id = seq.name.split('_')[1]\n",
    "    all_sequences[f'nexus_{seq_id}'] = sorted(seq.iterdir())\n",
    "\n",
    "sequence_lengths = {name: len(frames) for name, frames in all_sequences.items()}\n",
    "print('Sequences:', sequence_lengths)\n",
    "\n",
    "# `images` is the list consumed by the inference cell.\n",
    "cam = 'gopro'\n",
    "images = all_sequences[cam]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4) RobotCar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pixloc.run_RobotCar import default_paths, experiment\n",
    "\n",
    "paths = default_paths.add_prefixes(DATA_PATH/'RobotCar', LOC_PATH/'RobotCar')\n",
    "\n",
    "# The JPEG query images are looked up under <query_images>/<condition>/<cam>.\n",
    "condition = 'snow'\n",
    "cam = 'right'\n",
    "image_dir = paths.query_images / condition / cam\n",
    "# `images` is the list consumed by the inference cell.\n",
    "images = sorted(image_dir.glob('*.jpg'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference\n",
    "\n",
    "Choose one of the two models pretrained on CMU or MegaDepth:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[09/24/2021 15:40:42 pixloc.pixlib.utils.experiments INFO] Loading checkpoint checkpoint_best.tar\n"
     ]
    }
   ],
   "source": [
    "# Use the GPU when one is available and fall back to the CPU otherwise, so\n",
    "# that the notebook does not fail on machines without CUDA.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# `experiment` comes from whichever dataset cell was run above.\n",
    "pipeline = load_experiment(experiment).to(device).eval()\n",
    "\n",
    "# Wrap the extractor of the pipeline; it is configured with resize=1024.\n",
    "net = FeatureExtractor(pipeline.extractor, device, {'resize': 1024})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Set `preview=True` to visualize the predictions for a few images in the sequence.\n",
    "- Set `preview=False` to dump a video of the predictions to `uncertainties.mp4`, which the last cell then displays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████| 100/100 [00:31<00:00,  3.15it/s]\n",
      "[09/24/2021 15:41:31 pixloc.visualization.animation INFO] Running ffmpeg.\n"
     ]
    }
   ],
   "source": [
    "preview = True\n",
    "\n",
    "video = 'uncertainties.mp4'\n",
    "args = dict(cmaps='turbo', dpi=50)\n",
    "if preview:\n",
    "    # a few frames, sampled every 100 images\n",
    "    num = slice(10)\n",
    "    skip = slice(None, None, 100)\n",
    "else:\n",
    "    # up to 100 frames, sampled every 10 images\n",
    "    num = slice(100)\n",
    "    skip = slice(None, None, 10)\n",
    "    writer = VideoWriter('./tmp')\n",
    "\n",
    "for path in tqdm(images[skip][num]):\n",
    "    image = read_image(path)\n",
    "    _, _, confs = net(image)\n",
    "\n",
    "    # resample every confidence map to the resolution of the first one\n",
    "    shape = confs[0][0].shape\n",
    "    confs = [interpolate(c[None], size=shape, mode='bilinear', align_corners=False)[0] for c in confs]\n",
    "    confs = [c.cpu().numpy()[0] for c in confs]\n",
    "    fine, mid, coarse = confs\n",
    "\n",
    "    # fuse the confidence maps with a geometric mean (for visualization only)\n",
    "    fused = (fine*mid*coarse)**(1/3)\n",
    "\n",
    "    if preview:\n",
    "        plot_images([image, fine, mid, coarse, fused], **args)\n",
    "    else:\n",
    "        plot_images([image, fused], **args)\n",
    "    fig = plt.gcf()\n",
    "\n",
    "    # add image as background: make each confidence map translucent and draw\n",
    "    # the input image underneath it, over the same extent\n",
    "    for ax in fig.axes[1:]:\n",
    "        heatmap = ax.images[0]\n",
    "        heatmap.set(alpha=0.8)\n",
    "        ax.imshow(image, extent=heatmap.get_extent(), zorder=-1)\n",
    "\n",
    "    if not preview:\n",
    "        writer.add_frame()\n",
    "        # Close the figure once the frame is recorded; otherwise one figure per\n",
    "        # frame stays open for the whole loop.\n",
    "        # NOTE(review): assumes add_frame() is done with the current figure\n",
    "        # when it returns - confirm against VideoWriter.\n",
    "        plt.close(fig)\n",
    "\n",
    "if not preview:\n",
    "    writer.to_video(video, fps=3, crf=23)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the video file written by the inference cell. That file is only\n",
    "# produced when the inference cell was run with preview=False.\n",
    "display_video(video)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
